#!/usr/bin/python
# This file is part of KEmuFuzzer.
# 
# KEmuFuzzer is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
# 
# KEmuFuzzer is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along with
# KEmuFuzzer.  If not, see <http://www.gnu.org/licenses/>.

# KEmuFuzzer
# Test-case template and instance
#
# <testcase ring="0">
#   <ring0>
# mov KEF_INT, %eax;
# mov %eax, %ebx;
# add $0x100, %ebx;
#   </ring0>
#
#   <ring3>
# mov %eax, %ebx;
# KEF_PREFIX add $0x100, %ebx;
#   </ring3>
# </testcase>
#

import sys, random, subprocess, tempfile, os, re
import xml.dom.minidom
from elf import Elf
from tc import TestCase

# Path of the kernel image ("kernel/kernel"), resolved relative to the
# directory that contains the script being run (sys.argv[0]).
KERNEL = os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), "kernel/kernel")

# Segment-selector values for each ring (0-3).  Every literal is OR'ed with the
# ring number -- presumably the selector's requested privilege level; confirm
# against the kernel's descriptor tables.
RINGS_SELECTORS = {
    0: {'cs': (0x68 | 0), 'ds': (0xb0 | 0), 'ss': (0x70 | 0)},
    1: {'cs': (0xc0 | 1), 'ds': (0x88 | 1), 'ss': (0x90 | 1)},
    2: {'cs': (0x98 | 2), 'ds': (0xa0 | 2), 'ss': (0xa8 | 2)},
    3: {'cs': (0x78 | 3), 'ds': (0xb8 | 3), 'ss': (0x80 | 3)},
}
# The remaining data-segment registers reuse the ring's 'ds' selector.  Only
# the per-ring dictionaries are needed here, so iterate the values directly
# (this also avoids the Python-2-only iteritems()).
for v in RINGS_SELECTORS.values():
    v['es'] = v['fs'] = v['gs'] = v['ds']

def mutator_int(instr):
    """Expand ``KEF_INT`` into a fixed list of boundary immediates.

    Returns one copy of *instr* per immediate, in the order listed below,
    with every occurrence of ``KEF_INT`` replaced by that immediate.
    """
    immediates = ("$0x0", "$0xFF", "$0x80", "$0xFFFF", "$0x80000000", "$0xFFFFFFFF")
    return [instr.replace("KEF_INT", imm) for imm in immediates]

def guess_specifier(imm):
    """Guess the size-specifier letter for an immediate hex operand.

    *imm* must start with ``$0x``.  The specifier is chosen from the number of
    characters left once ``$0x`` is removed: 1-2 -> "b", 3-4 -> "w",
    5-8 -> "l", 9-16 -> "q".  Any other length trips an assertion.
    """
    assert imm.startswith("$0x"), "[!] Not an immediate hex operand: '%s'" % imm
    imm = imm.replace("$0x", "")
    ndigits = len(imm)

    # (largest digit count, specifier) pairs, in ascending order.
    for limit, letter in ((2, "b"), (4, "w"), (8, "l"), (16, "q")):
        if 1 <= ndigits <= limit:
            spec = letter
            break
    else:
        assert False, "[!] Unknown size-specifier for hex operand '%s'" % imm

    return spec

def mutator_spec_int(instr):
    """*Specified* integer mutator.

    Mutates the ``KEF_SPEC_INT`` immediate like the plain integer mutator, and
    also augments the opcode (the first whitespace-separated token of *instr*)
    with the size-specifier that guess_specifier() picks for each immediate.
    The instruction is re-joined with single spaces.
    """
    tokens = instr.split()
    opcode, operands = tokens[0].strip(), tokens[1:]

    variants = []
    for imm in ("$0x0", "$0xFF", "$0x80", "$0xFFFF", "$0x80000000", "$0xFFFFFFFF"):
        sized = " ".join([opcode + guess_specifier(imm)] + operands)
        variants.append(sized.replace("KEF_SPEC_INT", imm))
    return variants

def mutator_int_8(instr):
    """Expand ``KEF_INT8`` into 8-bit boundary immediates.

    Returns one copy of *instr* per immediate ($0x0, $0xFF, $0x80).  The whole
    ``KEF_INT8`` token is replaced; replacing only the ``KEF_INT`` prefix would
    leave the trailing "8" glued to the immediate (e.g. ``$0xFF8``).
    """
    return [instr.replace("KEF_INT8", imm) for imm in ("$0x0", "$0xFF", "$0x80")]

def mutator_int_16(instr):
    """Expand ``KEF_INT16`` into 16-bit boundary immediates.

    Returns one copy of *instr* per immediate ($0x0, $0xFFFF, $0x8000).  The
    whole ``KEF_INT16`` token is replaced; replacing only the ``KEF_INT``
    prefix would leave the trailing "16" glued to the immediate.
    """
    return [instr.replace("KEF_INT16", imm) for imm in ("$0x0", "$0xFFFF", "$0x8000")]

def mutator_int_32(instr):
    """Expand ``KEF_INT32`` into 32-bit boundary immediates.

    Returns one copy of *instr* per immediate ($0x0, $0x80000000,
    $0xFFFFFFFF).  The whole ``KEF_INT32`` token is replaced; replacing only
    the ``KEF_INT`` prefix would leave the trailing "32" glued to the
    immediate.
    """
    return [instr.replace("KEF_INT32", imm)
            for imm in ("$0x0", "$0x80000000", "$0xFFFFFFFF")]

def mutator_prefix(instr):
    """Expand ``KEF_PREFIX`` into raw prefix-byte variants.

    The first result is *instr* with the placeholder simply removed; each
    further result puts one assembler data directive in front of that bare
    instruction, separated by " ; ".
    """
    bare = instr.replace("KEF_PREFIX", "")
    prefixes = (".byte 0xf3", ".byte 0x66", ".byte 0xf0", ".word 0xf0f3")
    return [bare] + ["%s ; %s" % (prefix, bare) for prefix in prefixes]

def mutator_datasz(instr):
    """Expand ``KEF_DATASZ`` into the plain and 0x66-prefixed variants.

    Returns *instr* with the placeholder removed, followed by the same
    instruction preceded by ``.byte 0x66`` and " ; ".
    """
    bare = instr.replace("KEF_DATASZ", "")
    variants = [bare]
    variants.extend("%s ; %s" % (prefix, bare) for prefix in (".byte 0x66",))
    return variants

def mutator_jump_ring3(instr):
    """Return the fixed byte sequence emitted for ``KEF_JUMP_RING3``.

    *instr* is ignored.  The sequence is byte 0xea, a 32-bit zero and the
    16-bit value 0x3b -- this looks like a far jump to offset 0 through
    selector 0x3b; confirm against the kernel's descriptor layout.
    """
    raw_bytes = ".byte 0xea; .long 0x00; .word 0x3b;"
    return [raw_bytes]

def mutator_enter_vm86(instr):
    """Return the fixed byte sequence emitted for ``KEF_ENTER_VM86``.

    *instr* is ignored.  The sequence is byte 0xea, a 32-bit zero and the
    16-bit value 0xd3 -- this looks like a far jump to offset 0 through
    selector 0xd3; confirm against the kernel's descriptor layout.
    """
    raw_bytes = ".byte 0xea; .long 0x00; .word 0xd3;"
    return [raw_bytes]

def mutator_pt_base(instr):
    """Replace ``KEF_PT_BASE`` with the address of the kernel symbol "pt".

    The address is looked up in the KERNEL image and formatted as an 8-digit
    hex immediate.  Returns a single-element list.
    """
    address = Elf(KERNEL).getSymbol("pt").getAddress()
    return [instr.replace("KEF_PT_BASE", u"$0x%.8x" % address)]

def mutator_tcVM86_base(instr):
    """Replace ``KEF_TCVM86_BASE`` with the address of "tc_ringvm_base".

    The symbol is looked up in the KERNEL image and its address is formatted
    as an 8-digit hex immediate.  Returns a single-element list.
    """
    address = Elf(KERNEL).getSymbol("tc_ringvm_base").getAddress()
    return [instr.replace("KEF_TCVM86_BASE", u"$0x%.8x" % address)]

#### Ring-base mutators ####
def mutator_ringN_base(instr, n):
    """Replace ``KEF_RING<n>_BASE`` with the low address of ``.tcring<n>``.

    The section is looked up in the KERNEL image and its low address is
    formatted as an 8-digit hex immediate.  Returns a single-element list.
    """
    low_addr = Elf(KERNEL).getSection(".tcring%d" % n).getLowAddr()
    placeholder = "KEF_RING%d_BASE" % n
    return [instr.replace(placeholder, u"$0x%.8x" % low_addr)]

def mutator_ring0_base(instr):
    """Expand ``KEF_RING0_BASE`` via mutator_ringN_base() for ring 0."""
    return mutator_ringN_base(instr, n=0)
def mutator_ring1_base(instr):
    """Expand ``KEF_RING1_BASE`` via mutator_ringN_base() for ring 1."""
    return mutator_ringN_base(instr, n=1)
def mutator_ring2_base(instr):
    """Expand ``KEF_RING2_BASE`` via mutator_ringN_base() for ring 2."""
    return mutator_ringN_base(instr, n=2)
def mutator_ring3_base(instr):
    """Expand ``KEF_RING3_BASE`` via mutator_ringN_base() for ring 3."""
    return mutator_ringN_base(instr, n=3)
############################

#### Ring-selector mutators ####
def mutator_ringN_selector(instr, n, sel):
    """Replace ``KEF_RING<n>_<SEL>`` with the matching segment selector.

    *sel* is a lower-case segment-register name used as a key into
    RINGS_SELECTORS[n]; the placeholder uses its upper-case form.  The value
    is formatted as a 4-digit hex immediate.  Returns a single-element list.
    """
    placeholder = "KEF_RING%d_%s" % (n, sel.upper())
    selector = RINGS_SELECTORS[n][sel]
    return [instr.replace(placeholder, u"$0x%.4x" % selector)]

def mutator_ring0_cs(instr):
    """Expand ``KEF_RING0_CS`` via mutator_ringN_selector() (ring 0, 'cs')."""
    return mutator_ringN_selector(instr, n=0, sel='cs')
def mutator_ring0_ds(instr):
    """Expand ``KEF_RING0_DS`` via mutator_ringN_selector() (ring 0, 'ds')."""
    return mutator_ringN_selector(instr, n=0, sel='ds')
def mutator_ring0_ss(instr):
    """Expand ``KEF_RING0_SS`` via mutator_ringN_selector() (ring 0, 'ss')."""
    return mutator_ringN_selector(instr, n=0, sel='ss')
def mutator_ring3_cs(instr):
    """Expand ``KEF_RING3_CS`` via mutator_ringN_selector() (ring 3, 'cs')."""
    return mutator_ringN_selector(instr, n=3, sel='cs')
def mutator_ring3_ds(instr):
    """Expand ``KEF_RING3_DS`` via mutator_ringN_selector() (ring 3, 'ds')."""
    return mutator_ringN_selector(instr, n=3, sel='ds')
def mutator_ring3_ss(instr):
    """Expand ``KEF_RING3_SS`` via mutator_ringN_selector() (ring 3, 'ss')."""
    return mutator_ringN_selector(instr, n=3, sel='ss')
################################

def mutator_random_operand(instr):
    """Expand ``KEF_RANDOM_OPERAND(op1, op2, ...)`` into one variant per operand.

    The argument list is split on commas; each operand in turn replaces the
    whole ``KEF_RANDOM_OPERAND(...)`` expression in *instr*.  An operand
    wrapped in double quotes is unquoted first, and if it contains backslashes
    it is read as a sequence of ``\\xNN`` escapes and turned into a
    ``.byte 0xNN, ...`` directive.

    Limitations inherited from the simple parsing: at least two operands are
    required, operands must not contain commas, and the match extends to the
    last ")" on the line.

    Raises AssertionError if *instr* holds no well-formed placeholder.
    """
    m = re.search(r"KEF_RANDOM_OPERAND\((.*,.*)\)", instr)
    assert m, "KEF_RANDOM_OPERAND(op1, op2, op3)"

    instrs = []
    for op in m.group(1).split(","):
        # Whitespace after a comma must not hide the surrounding quotes.
        candidate = op.strip()
        if candidate.startswith("\"") and candidate.endswith("\""):
            op = candidate.strip("\"")

            if "\\" in op:
                byte_values = ["0x" + c for c in op.split("\\x") if len(c) > 0]
                op = ".byte " + ", ".join(byte_values)

        # A callable replacement inserts the operand verbatim; a plain string
        # would have its backslashes interpreted as re.sub() template escapes.
        instrs.append(re.sub(r"KEF_RANDOM_OPERAND\(.*\)",
                             lambda _match, text=op: text, instr))

    return instrs

def mutator_bits_32(instr):
    """Expand ``KEF_BITS32`` into the 32 single-bit 32-bit immediates.

    Variant *i* (0..31) has only bit *i* set, formatted as 8 hex digits.
    """
    return [instr.replace("KEF_BITS32", "$0x%.8x" % (1 << bit))
            for bit in range(32)]

def mutator_mask_32(instr):
    """Expand ``KEF_MASK32`` into the 32 single-bit-cleared 32-bit masks.

    Variant *i* (0..31) has every bit set except bit *i*, formatted as 8 hex
    digits.
    """
    all_ones = 0xffffffff
    return [instr.replace("KEF_MASK32", "$0x%.8x" % (all_ones ^ (1 << bit)))
            for bit in range(32)]

def mutator_bitno_32(instr):
    """Expand ``KEF_BITNO32`` into the bit numbers 0..31.

    Each number is formatted as a 2-digit hex immediate.
    """
    return [instr.replace("KEF_BITNO32", "$0x%.2x" % bit) for bit in range(32)]

# Dispatch table: template placeholder token -> function that expands it.
# Every mutator takes one instruction line and returns a list of concrete
# instruction strings; mutate() runs the mutators whose token appears in a
# line.
MUTATORS = {
    # Integer immediates
    'KEF_INT' : mutator_int,
    'KEF_SPEC_INT' : mutator_spec_int,
    'KEF_BITNO32' : mutator_bitno_32,
    'KEF_INT8' : mutator_int_8,
    'KEF_INT16' : mutator_int_16,
    'KEF_INT32' : mutator_int_32,
    'KEF_BITS32' : mutator_bits_32,
    'KEF_MASK32' : mutator_mask_32,
    # Raw prefix bytes placed in front of the instruction
    'KEF_PREFIX' : mutator_prefix,
    'KEF_DATASZ' : mutator_datasz,
    # Fixed byte sequences (the instruction text itself is discarded)
    'KEF_JUMP_RING3' : mutator_jump_ring3,
    'KEF_ENTER_VM86' : mutator_enter_vm86,
    # Addresses read from the kernel image and per-ring segment selectors
    'KEF_RING0_BASE': mutator_ring0_base,
    'KEF_RING0_CS': mutator_ring0_cs,
    'KEF_RING0_DS': mutator_ring0_ds,
    'KEF_RING0_SS': mutator_ring0_ss,
    'KEF_RING3_BASE': mutator_ring3_base,
    'KEF_RING3_CS': mutator_ring3_cs,
    'KEF_RING3_DS': mutator_ring3_ds,
    'KEF_RING3_SS': mutator_ring3_ss,
    'KEF_PT_BASE': mutator_pt_base,
    # One variant per operand listed in KEF_RANDOM_OPERAND(...)
    'KEF_RANDOM_OPERAND' : mutator_random_operand,
    'KEF_TCVM86_BASE':mutator_tcVM86_base,
}

def product(*args, **kwds):
    """Yield the cartesian product of the input iterables as tuples.

    product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
    product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111

    The optional keyword ``repeat`` (default 1) repeats the whole sequence of
    input iterables that many times.  With no iterables a single empty tuple
    is produced.
    """
    # Build a real list so that it can be repeated with "*" even where map()
    # returns an iterator instead of a list.
    pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1)
    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]
    for prod in result:
        yield tuple(prod)

def mutate(rings):
    """Mutate and combine instructions.

    *rings* maps a ring index (0-3, plus 4 for the vm86 code) to that ring's
    source text or None.  Each non-blank line is expanded by the mutators in
    MUTATORS whose placeholder occurs in it (at most one placeholder per
    instruction is assumed); lines without a placeholder are kept as they
    are.  The per-line variants are combined into every possible ring body,
    and the ring bodies into every possible test case.

    Returns a list of dictionaries {0: code, 1: code, 2: code, 3: code,
    4: code}; a ring without code contributes the empty string.
    """
    # Placeholders are matched as whole tokens, so that 'KEF_INT' does not
    # also fire on 'KEF_INT8', 'KEF_INT16' or 'KEF_INT32'.
    matchers = [(re.compile(r"\b%s\b" % re.escape(key)), mutator)
                for key, mutator in MUTATORS.items()]

    # Each ring starts with a single empty variant; an empty pool for a ring
    # missing from *rings* would otherwise wipe out the whole product below.
    mutated_rings = [[""] for _ in range(5)]

    for ring, instrs in rings.items():
        mutated_instrs = []
        if instrs:
            lines = [i for i in instrs.split("\n") if i and i.replace(" ", "")]
            for instr in lines:
                processed = False
                mutated_instr = []

                for pattern, mutator in matchers:
                    if pattern.search(instr):
                        # generate mutations
                        mutated_instr += mutator(instr)
                        processed = True

                # instruction with no mutator operand
                if not processed:
                    mutated_instr += [instr]

                mutated_instrs += [mutated_instr]

        # generate combinations (instructions)
        mutated_rings[ring] = ["\n".join(combo) for combo in product(*mutated_instrs)]

    # generate combinations (rings)
    return [dict(enumerate(combo)) for combo in product(*mutated_rings)]

def assemble(code, directive=""):
    """Compile a test-case: assemble *code* and return its raw .text bytes.

    *code* is GNU assembler source; *directive* is an optional line placed
    right after ``.text`` (e.g. ".code16").  The source is piped to
    ``as -32`` and the resulting object is reduced in place to its .text
    section with ``objcopy -O binary``.

    Returns "" when *code* is empty, when the assembler writes anything to
    stderr, or when objcopy fails; a warning is printed in the last two cases.
    The temporary object file is always removed.
    """
    if not code:
        return ""

    # Dummy interrupt (iret) to touch the page containing interrupt handlers and to
    # prevent synthetic pagefaults in KVM (that might corrupt RF).
    # Disabled (better to handle with a normalization rule)
    # code = "int $0x21;\n" + code

    # mkstemp() creates the file atomically, unlike the race-prone mktemp();
    # the assembler simply overwrites it.
    fd, tmpobj = tempfile.mkstemp(prefix = "kemufuzzer-tc-")
    os.close(fd)

    try:
        # assemble (argument lists keep working if the temp path has spaces)
        p = subprocess.Popen(["as", "-32", "-o", tmpobj, "-"],
                             stdin = subprocess.PIPE,
                             stdout = subprocess.PIPE,
                             stderr = subprocess.PIPE)
        # str() was added to avoid unicode problems!
        prog = str("\n.text\n" + directive + "\n" + code + "\n")
        o, e = p.communicate(prog)

        if e is not None and e != "":
            print("[W] Can't compile mutation '%s'" % prog.replace("\n", " "))
            return ''

        # Strip the elf and keep only .text section
        if subprocess.call(["objcopy", "-O", "binary", "-j", ".text", tmpobj]) != 0:
            # Without this check the whole object file would be returned as code.
            print("[W] Can't extract .text of mutation '%s'" % prog.replace("\n", " "))
            return ''

        obj = open(tmpobj, "rb")
        try:
            return obj.read()
        finally:
            obj.close()
    finally:
        # Runs on every exit path, including the failure returns above.
        if os.path.isfile(tmpobj):
            os.unlink(tmpobj)

def randomize_rings(f):
    """Expand the ``random_ring`` attribute of a test-case template.

    *f* is the path of a template file.  Without a ``random_ring`` attribute
    on <testcase>, the file content is returned unchanged as a one-element
    list.  Otherwise the code of ring ``random_ring`` is replicated, by plain
    text substitution, into that ring and into every ring (0-3) the template
    does not define; one template string is returned per target ring, in
    ascending ring order.

    Raises AssertionError if the template has no <ringN> element for the ring
    named by ``random_ring``.
    """
    template = open(f, 'r')
    try:
        s = template.read()
    finally:
        template.close()

    dom = xml.dom.minidom.parseString(s)
    testcase = dom.getElementsByTagName("testcase")[0]
    random_ring = testcase.getAttribute("random_ring")

    if not random_ring:
        # No ring randomization
        return [s]

    random_ring = int(random_ring)

    # Rings for which the template provides code
    rings = {}
    for i in range(4):
        ring = testcase.getElementsByTagName("ring%d" % i)
        if len(ring) > 0:
            rings[i] = ring[0]
    assert random_ring in rings, "[!] Can't randomize rings without the base one!"

    print("[*] Randomizing ring %d" % random_ring)
    # Targets: every ring the template leaves free, plus the base ring itself
    random_rings = set(range(4)) - set(rings.keys())
    random_rings.add(random_ring)

    ss = []
    for r in sorted(random_rings):
        # Drop the attribute first: its text also contains 'ring="N"'
        tmp = s.replace("random_ring=\"%d\"" % random_ring, "")
        tmp = tmp.replace("ring=\"%d\"" % random_ring, "ring=\"%d\"" % r)
        tmp = tmp.replace("<ring%d>" % random_ring, "<ring%d>" % r)
        tmp = tmp.replace("</ring%d>" % random_ring, "</ring%d>" % r)

        ss.append(tmp)

    return ss

def tcinst(f):
    """Instantiate test-cases from a template.

    *f* is the path of a ".template" file.  The template is expanded by
    randomize_rings(), parsed, mutated and assembled; every mutation that
    yields at least one non-empty ring is saved next to the template as
    "<name>.NNNN.testcase".  Progress messages go to stderr.

    Returns the list of test-case files written.  Raises AssertionError if
    *f* does not end in ".template".
    """
    suffix = ".template"
    assert(f.endswith(suffix))
    sys.stderr.write("[*] Loading test-case template from '%s'\n" % (f))

    # Strip only the trailing suffix: a blanket str.replace() would also
    # rewrite ".template" appearing earlier in the path (e.g. in a directory).
    stem = f[:-len(suffix)]

    # HACK: we check for ring mutators before parsing the test-case string. In
    # this way, it is much easier :-)
    ss = randomize_rings(f)

    R = []
    for s in ss:
        # Parse testcase
        rings, start_ring = parse(s)

        # Mutate and combine test case
        testcases = mutate(rings)
        sys.stderr.write("[*] Mutating test-case (%d mutations)\n" % (len(testcases)))

        TC = []
        for ring in testcases:
            d = {
                0: assemble(ring[0]), 1: assemble(ring[1]),
                2: assemble(ring[2]), 3: assemble(ring[3]),
                4: assemble(ring[4], ".code16")    # "ring 4" is reserved for code to be executed in vm86-mode
                 }

            # Check if we have assembled at least one instruction
            if set(d.values()) != set(['']):
                testcase = TestCase(d, start_ring)
                # Numbering continues across the ring-randomized variants
                tmp = "%s.%.4d.testcase" % (stem, len(R) + len(TC))
                sys.stderr.write("[*] Saving mutated test-case '%s'\n" % tmp)
                testcase.save(tmp)
                TC += [tmp]
        R.extend(TC)

    return R

def parse(s):
    """Parse a test-case template string.

    Returns ``(rings, start_ring)``.  *rings* maps 0-3 to the text of the
    <ring0>..<ring3> elements and 4 to the text of <ringvm> (code to execute
    in vm86-mode); a missing element maps to None.  *start_ring* is the
    integer value of the ``ring`` attribute of <testcase>, or 4 when that
    attribute is "vm".

    Raises ValueError if the ``ring`` attribute is neither "vm" nor an
    integer, and IndexError if there is no <testcase> element.
    """
    def getText(nodelist):
        # Concatenate the direct text children only
        return "".join([node.data for node in nodelist
                        if node.nodeType == node.TEXT_NODE])

    dom = xml.dom.minidom.parseString(s)
    testcase = dom.getElementsByTagName("testcase")[0]

    # No debug output here: stdout carries the list of generated files.
    start_ring = testcase.getAttribute("ring")
    if start_ring == "vm":
        start_ring = 4
    else:
        start_ring = int(start_ring)

    # Index 4 is the slot for the vm86-mode code
    tags = ["ring%d" % i for i in range(4)] + ["ringvm"]

    rings = {}
    for index, tag in enumerate(tags):
        elements = testcase.getElementsByTagName(tag)
        if elements:
            rings[index] = getText(elements[0].childNodes)
        else:
            rings[index] = None

    return rings, start_ring

###############################################################################

if __name__ == "__main__":
    # Each command-line argument is a template file; print the space-separated
    # list of test-cases generated from it, followed by an empty line.
    for i in sys.argv[1:]:
        tcs = tcinst(i)
        print(" ".join(tcs))
        print("")

    # sys.exit() is always available; the bare exit() helper is installed by
    # the site module and is missing when the interpreter runs with -S.
    sys.exit(0)
